{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Adam\n",
    ":label:`sec_adam`\n",
    "\n",
    "In the discussions leading up to this section we encountered a number of techniques for efficient optimization. Let us recap them in detail here:\n",
    "\n",
    "* We saw that :numref:`sec_sgd` is more effective than Gradient Descent when solving optimization problems, e.g., due to its inherent resilience to redundant data. \n",
    "* We saw that :numref:`sec_minibatch_sgd` affords significant additional efficiency arising from vectorization, using larger sets of observations in one minibatch. This is the key to efficient multi-machine, multi-GPU and overall parallel processing. \n",
    "* :numref:`sec_momentum` added a mechanism for aggregating a history of past gradients to accelerate convergence.\n",
    "* :numref:`sec_adagrad` used per-coordinate scaling to allow for a computationally efficient preconditioner. \n",
    "* :numref:`sec_rmsprop` decoupled per-coordinate scaling from a learning rate adjustment. \n",
    "\n",
    "Adam :cite:`Kingma.Ba.2014` combines all these techniques into one efficient learning algorithm. As expected, this is an algorithm that has become rather popular as one of the more robust and effective optimization algorithms to use in deep learning. It is not without issues, though. In particular, :cite:`Reddi.Kale.Kumar.2019` show that there are situations where Adam can diverge due to poor variance control. In a follow-up work :cite:`Zaheer.Reddi.Sachan.ea.2018` proposed a hotfix to Adam, called Yogi which addresses these issues. More on this later. For now let us review the Adam algorithm. \n",
    "\n",
    "## The Algorithm\n",
    "\n",
    "One of the key components of Adam is that it uses exponential weighted moving averages (also known as leaky averaging) to obtain an estimate of both the momentum and also the second moment of the gradient. That is, it uses the state variables\n",
    "\n",
    "$$\\begin{aligned}\n",
    "    \\mathbf{v}_t & \\leftarrow \\beta_1 \\mathbf{v}_{t-1} + (1 - \\beta_1) \\mathbf{g}_t, \\\\\n",
    "    \\mathbf{s}_t & \\leftarrow \\beta_2 \\mathbf{s}_{t-1} + (1 - \\beta_2) \\mathbf{g}_t^2.\n",
    "\\end{aligned}$$\n",
    "\n",
    "Here $\\beta_1$ and $\\beta_2$ are nonnegative weighting parameters. Common choices for them are $\\beta_1 = 0.9$ and $\\beta_2 = 0.999$. That is, the variance estimate moves *much more slowly* than the momentum term. Note that if we initialize $\\mathbf{v}_0 = \\mathbf{s}_0 = 0$ we have a significant amount of bias initially towards smaller values. This can be addressed by using the fact that $\\sum_{i=0}^t \\beta^i = \\frac{1 - \\beta^t}{1 - \\beta}$ to re-normalize terms. Correspondingly the normalized state variables are given by \n",
    "\n",
    "$$\\hat{\\mathbf{v}}_t = \\frac{\\mathbf{v}_t}{1 - \\beta_1^t} \\text{ and } \\hat{\\mathbf{s}}_t = \\frac{\\mathbf{s}_t}{1 - \\beta_2^t}.$$\n",
    "\n",
    "Armed with the proper estimates we can now write out the update equations. First, we rescale the gradient in a manner very much akin to that of RMSProp to obtain\n",
    "\n",
    "$$\\mathbf{g}_t' = \\frac{\\eta \\hat{\\mathbf{v}}_t}{\\sqrt{\\hat{\\mathbf{s}}_t} + \\epsilon}.$$\n",
    "\n",
    "Unlike RMSProp our update uses the momentum $\\hat{\\mathbf{v}}_t$ rather than the gradient itself. Moreover, there is a slight cosmetic difference as the rescaling happens using $\\frac{1}{\\sqrt{\\hat{\\mathbf{s}}_t} + \\epsilon}$ instead of $\\frac{1}{\\sqrt{\\hat{\\mathbf{s}}_t + \\epsilon}}$. The former works arguably slightly better in practice, hence the deviation from RMSProp. Typically we pick $\\epsilon = 10^{-6}$ for a good trade-off between numerical stability and fidelity. \n",
    "\n",
    "Now we have all the pieces in place to compute updates. This is slightly anticlimactic and we have a simple update of the form\n",
    "\n",
    "$$\\mathbf{x}_t \\leftarrow \\mathbf{x}_{t-1} - \\mathbf{g}_t'.$$\n",
    "\n",
    "Reviewing the design of Adam its inspiration is clear. Momentum and scale are clearly visible in the state variables. Their rather peculiar definition forces us to debias terms (this could be fixed by a slightly different initialization and update condition). Second, the combination of both terms is pretty straightforward, given RMSProp. Last, the explicit learning rate $\\eta$ allows us to control the step length to address issues of convergence. \n",
    "\n",
    "## Implementation \n",
    "\n",
    "Implementing Adam from scratch is not very daunting. For convenience we store the timestep counter $t$ in the `hyperparams` dictionary. Beyond that all is straightforward.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/GradDescUtils.java\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/StopWatch.java\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/TrainingChapter11.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDList initAdamStates(int featureDimension) {\n",
    "    NDManager manager = NDManager.newBaseManager();\n",
    "    NDArray vW = manager.zeros(new Shape(featureDimension, 1));\n",
    "    NDArray vB = manager.zeros(new Shape(1));\n",
    "    NDArray sW = manager.zeros(new Shape(featureDimension, 1));\n",
    "    NDArray sB = manager.zeros(new Shape(1));\n",
    "    return new NDList(vW, sW, vB, sB);\n",
    "}\n",
    "\n",
    "public class Optimization {\n",
    "    public static void adam(NDList params, NDList states, Map<String, Float> hyperparams) {\n",
    "        float beta1 = 0.9f;\n",
    "        float beta2 = 0.999f;\n",
    "        float eps = (float) 1e-6;\n",
    "        float time = hyperparams.get(\"time\");\n",
    "        for (int i = 0; i < params.size(); i++) {\n",
    "            NDArray param = params.get(i);\n",
    "            NDArray velocity = states.get(2 * i);\n",
    "            NDArray state = states.get(2 * i + 1);\n",
    "            // Update parameter, velocity, and state\n",
    "            // velocity = beta1 * v + (1 - beta1) * param.gradient\n",
    "            velocity.muli(beta1).addi(param.getGradient().mul(1 - beta1));\n",
    "            // state = beta2 * state + (1 - beta2) * param.gradient^2\n",
    "            state.muli(beta2).addi(param.getGradient().square().mul(1 - beta2));\n",
    "            // vBiasCorr = velocity / ((1 - beta1)^(time))\n",
    "            NDArray vBiasCorr = velocity.div(1 - Math.pow(beta1, time));\n",
    "            // sBiasCorr = state / ((1 - beta2)^(time))\n",
    "            NDArray sBiasCorr = state.div(1 - Math.pow(beta2, time));\n",
    "            // param -= lr * vBiasCorr / (sBiasCorr^(1/2) + eps)\n",
    "            param.subi(vBiasCorr.mul(hyperparams.get(\"lr\")).div(sBiasCorr.sqrt().add(eps)));\n",
    "        }\n",
    "        hyperparams.put(\"time\", time + 1);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 2
   },
   "source": [
    "We are ready to use Adam to train the model. We use a learning rate of $\\eta = 0.01$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "AirfoilRandomAccess airfoil = TrainingChapter11.getDataCh11(10, 1500);\n",
    "\n",
    "public TrainingChapter11.LossTime trainAdam(float lr, float time, int numEpochs) throws IOException, TranslateException {\n",
    "    int featureDimension = airfoil.getColumnNames().size();\n",
    "    Map<String, Float> hyperparams = new HashMap<>();\n",
    "    hyperparams.put(\"lr\", lr);\n",
    "    hyperparams.put(\"time\", time);\n",
    "    return TrainingChapter11.trainCh11(Optimization::adam, \n",
    "                                       initAdamStates(featureDimension), \n",
    "                                       hyperparams, airfoil, \n",
    "                                       featureDimension, numEpochs);\n",
    "}\n",
    "\n",
    "TrainingChapter11.LossTime lossTime = trainAdam(0.01f, 1, 2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "A more concise implementation is straightforward since `adam` is one of the algorithms provided as part of the DJL optimization library.\n",
    "\n",
    "We will set the learning rate to 0.01f to remain consistent with the previous section.\n",
    "However, you typically won't need to set this yourself as the default will usually work fine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Tracker lrt = Tracker.fixed(0.01f);\n",
    "Optimizer adam = Optimizer.adam().optLearningRateTracker(lrt).build();\n",
    "\n",
    "TrainingChapter11.trainConciseCh11(adam, airfoil, 2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "## Yogi\n",
    "\n",
    "One of the problems of Adam is that it can fail to converge even in convex settings when the second moment estimate in $\\mathbf{s}_t$ blows up. As a fix :cite:`Zaheer.Reddi.Sachan.ea.2018` proposed a refined update (and initialization) for $\\mathbf{s}_t$. To understand what's going on, let us rewrite the Adam update as follows:\n",
    "\n",
    "$$\\mathbf{s}_t \\leftarrow \\mathbf{s}_{t-1} + (1 - \\beta_2) \\left(\\mathbf{g}_t^2 - \\mathbf{s}_{t-1}\\right).$$\n",
    "\n",
    "Whenever $\\mathbf{g}_t^2$ has high variance or updates are sparse, $\\mathbf{s}_t$ might forget past values too quickly. A possible fix for this is to replace $\\mathbf{g}_t^2 - \\mathbf{s}_{t-1}$ by $\\mathbf{g}_t^2 \\odot \\mathop{\\mathrm{sgn}}(\\mathbf{g}_t^2 - \\mathbf{s}_{t-1})$. Now the magnitude of the update no longer depends on the amount of deviation. This yields the Yogi updates\n",
    "\n",
    "$$\\mathbf{s}_t \\leftarrow \\mathbf{s}_{t-1} + (1 - \\beta_2) \\mathbf{g}_t^2 \\odot \\mathop{\\mathrm{sgn}}(\\mathbf{g}_t^2 - \\mathbf{s}_{t-1}).$$\n",
    "\n",
    "The authors furthermore advise to initialize the momentum on a larger initial batch rather than just initial pointwise estimate. We omit the details since they are not material to the discussion and since even without this convergence remains pretty good.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class Optimization {\n",
    "    public static void yogi(NDList params, NDList states, Map<String, Float> hyperparams) {\n",
    "        float beta1 = 0.9f;\n",
    "        float beta2 = 0.999f;\n",
    "        float eps = (float) 1e-3;\n",
    "        float time = hyperparams.get(\"time\");\n",
    "        for (int i = 0; i < params.size(); i++) {\n",
    "            NDArray param = params.get(i);\n",
    "            NDArray velocity = states.get(2 * i);\n",
    "            NDArray state = states.get(2 * i + 1);\n",
    "            // Update parameter, velocity, and state\n",
    "            // velocity = beta1 * v + (1 - beta1) * param.gradient\n",
    "            velocity.muli(beta1).addi(param.getGradient().mul(1 - beta1));\n",
    "            /* Rewritten Update */\n",
    "            // state = state + (1 - beta2) * sign(param.gradient^2 - state) \n",
    "            //         * param.gradient^2\n",
    "            state.addi(param.getGradient().square().sub(state).sign().mul(1 - beta2));\n",
    "            // vBiasCorr = velocity / ((1 - beta1)^(time))\n",
    "            NDArray vBiasCorr = velocity.div(1 - Math.pow(beta1, time));\n",
    "            // sBiasCorr = state / ((1 - beta2)^(time))\n",
    "            NDArray sBiasCorr = state.div(1 - Math.pow(beta2, time));\n",
    "            // param -= lr * vBiasCorr / (sBiasCorr^(1/2) + eps)\n",
    "            param.subi(vBiasCorr.mul(hyperparams.get(\"lr\")).div(sBiasCorr.sqrt().add(eps)));\n",
    "        }\n",
    "        hyperparams.put(\"time\", time + 1);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "AirfoilRandomAccess airfoil = TrainingChapter11.getDataCh11(10, 1500);\n",
    "\n",
    "public TrainingChapter11.LossTime trainYogi(float lr, float time, int numEpochs) throws IOException, TranslateException {\n",
    "    int featureDimension = airfoil.getColumnNames().size();\n",
    "    Map<String, Float> hyperparams = new HashMap<>();\n",
    "    hyperparams.put(\"lr\", lr);\n",
    "    hyperparams.put(\"time\", time);\n",
    "    return TrainingChapter11.trainCh11(Optimization::yogi, \n",
    "                                       initAdamStates(featureDimension), \n",
    "                                       hyperparams, airfoil, \n",
    "                                       featureDimension, numEpochs);\n",
    "}\n",
    "\n",
    "trainYogi(0.01f, 1, 2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "## Summary\n",
    "\n",
    "* Adam combines features of many optimization algorithms into a fairly robust update rule. \n",
    "* Created on the basis of RMSProp, Adam also uses EWMA on the minibatch stochastic gradient\n",
    "* Adam uses bias correction to adjust for a slow startup when estimating momentum and a second moment. \n",
    "* For gradients with significant variance we may encounter issues with convergence. They can be amended by using larger minibatches or by switching to an improved estimate for $\\mathbf{s}_t$. Yogi offers such an alternative. \n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Adjust the learning rate and observe and analyze the experimental results.\n",
    "1. Can you rewrite momentum and second moment updates such that it does not require bias correction?\n",
    "1. Why do you need to reduce the learning rate $\\eta$ as we converge?\n",
    "1. Try to construct a case for which Adam diverges and Yogi converges?\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
